/*____________________________________________________________________________
	Copyright (C) 2000 Networks Associates Technology, Inc.
	All rights reserved.

	$Id: pkcs7Callbacks.c,v 1.10 2001/01/25 22:11:37 jeffc Exp $
____________________________________________________________________________*/
/*
 *	Author: Michael_Elkins@NAI.com
 *	Last Edit: December 9, 1999
 */

#include "x509CMS.h"

#include "pgpHash.h"
#include "pgpErrors.h"
#include "pgpSymmetricCipher.h"
#include "pgpCBC.h"
#include "pgpDebug.h"

/*
 * pkcs7HashCallback - hash `tbs` with the digest named by `hashAlgorithm`.
 *
 *   hashValue      [OUT] on success receives a buffer allocated from
 *                        asnmem->memMgr holding the digest, and its length.
 *   hashAlgorithm  OID string; only SM_OID_ALG_MD5 and SM_OID_ALG_SHA are
 *                  accepted.
 *   tbs, tbsLen    data to be hashed.
 *   data           X509CMSCallbackData* supplying the PGP context.
 *   asnmem         ASN.1 context whose memory manager owns the output.
 *
 * Returns PKCS7_OK, PKCS7_ERROR_HASH_ALG for an unsupported algorithm, or
 * PKCS7_ERROR_HASH_CALLBACK on any SDK or allocation failure.  If a failure
 * happens after the output buffer was allocated, hashValue->val is reset to
 * NULL and hashValue->len to 0; earlier failures leave hashValue untouched.
 */
int
pkcs7HashCallback (
	PKIOCTET_STRING		*hashValue,	/* OUT */
	const char		*hashAlgorithm,
	const unsigned char	*tbs,
	size_t			tbsLen,
	void			*data,
	PKICONTEXT		*asnmem
)
{
    PGPError err;
    PGPHashContextRef hash;
    PGPSize hashLen;
    PGPHashAlgorithm algid;
    X509CMSCallbackData *pgpData = (X509CMSCallbackData *) data;

    if (!strcmp (SM_OID_ALG_MD5, hashAlgorithm))
	algid = kPGPHashAlgorithm_MD5;
    else if (!strcmp (SM_OID_ALG_SHA, hashAlgorithm))
	algid = kPGPHashAlgorithm_SHA;
    else
	return PKCS7_ERROR_HASH_ALG;

    err = PGPNewHashContext (pgpData->context, algid, &hash);
    if (IsPGPError (err))
	return PKCS7_ERROR_HASH_CALLBACK;

    err = PGPContinueHash (hash, tbs, tbsLen);
    if (IsPGPError (err))
	goto fail;

    err = PGPGetHashSize (hash, &hashLen);
    if (IsPGPError (err))
	goto fail;

    /* never hand a NULL buffer to PGPFinalizeHash */
    hashValue->val = PKIAlloc (asnmem->memMgr, hashLen);
    if (hashValue->val == NULL)
    {
	hashValue->len = 0;
	goto fail;
    }
    hashValue->len = hashLen;

    err = PGPFinalizeHash (hash, hashValue->val);
    if (IsPGPError (err))
    {
	PKIFree (asnmem->memMgr, hashValue->val);
	hashValue->val = NULL;
	hashValue->len = 0;
	goto fail;
    }

    (void) PGPFreeHashContext (hash);
    return PKCS7_OK;

fail:
    (void) PGPFreeHashContext (hash);
    return PKCS7_ERROR_HASH_CALLBACK;
}

/*
 * pkcs7SignCallback - hash `tbs` and sign the digest with pgpData->key.
 *
 *   sigValue           [OUT] on success receives a buffer allocated from
 *                            asnmem->memMgr holding the signature.
 *   hashAlgorithm      OID string; only SM_OID_ALG_MD5 and SM_OID_ALG_SHA
 *                      are accepted.
 *   signatureAlgorithm unused.
 *   signerCertificate  unused.
 *   tbs, tbsLen        data to be signed.
 *   data               X509CMSCallbackData* supplying the PGP context, the
 *                      signing key and an optional passphrase option list.
 *   asnmem             ASN.1 context whose memory manager owns the output.
 *
 * Returns PKCS7_OK, PKCS7_ERROR_HASH_ALG for an unsupported digest, or
 * PKCS7_ERROR_SIGN_CALLBACK on any SDK or allocation failure.
 */
int
pkcs7SignCallback (
	PKIOCTET_STRING *sigValue,		/* OUT */
	const char	*hashAlgorithm,
	const char	*signatureAlgorithm,
	PKICertificate	*signerCertificate,
	unsigned char	*tbs,
	size_t		tbsLen,
	void		*data,
	PKICONTEXT	*asnmem)
{
    PGPError err;
    PGPHashContextRef hash;
    PGPHashAlgorithm algid;
    PGPPrivateKeyContextRef privkey;
    X509CMSCallbackData *pgpData = (X509CMSCallbackData *) data;
    PGPOptionListRef	pass;

	(void) signatureAlgorithm;
	(void) signerCertificate;

    if (!strcmp (SM_OID_ALG_MD5, hashAlgorithm))
	algid = kPGPHashAlgorithm_MD5;
    else if (!strcmp (SM_OID_ALG_SHA, hashAlgorithm))
	algid = kPGPHashAlgorithm_SHA;
    else
	return PKCS7_ERROR_HASH_ALG;

    err = PGPNewHashContext (pgpData->context, algid, &hash);
    if (IsPGPError (err))
	return PKCS7_ERROR_SIGN_CALLBACK;

    err = PGPContinueHash (hash, tbs, tbsLen);
    if (IsPGPError (err))
	goto fail_hash;

    /* the passphrase option list is copied because it is handed to
       PGPNewPrivateKeyContext below; a failed copy must not be used */
    if (pgpData->passphrase == NULL)
	pass = PGPONullOption (pgpData->context);
    else
    {
	err = PGPCopyOptionList (pgpData->passphrase, &pass);
	if (IsPGPError (err))
	    goto fail_hash;
    }

    err = PGPNewPrivateKeyContext (pgpData->key,
	    kPGPPublicKeyMessageFormat_X509,
	    &privkey,
	    pass,
	    PGPOLastOption (pgpData->context));
    if (IsPGPError (err))
	goto fail_hash;

    err = PGPGetPrivateKeyOperationSizes (privkey, NULL, NULL, &sigValue->len);
    if (IsPGPError (err))
	goto fail_key;

    sigValue->val = PKIAlloc (asnmem->memMgr, sigValue->len);
    if (sigValue->val == NULL)
    {
	sigValue->len = 0;
	goto fail_key;
    }

    err = PGPPrivateKeySign (privkey, hash, sigValue->val, &sigValue->len);
    if (IsPGPError (err))
    {
	PKIFree (asnmem->memMgr, sigValue->val);
	sigValue->val = NULL;
	sigValue->len = 0;
	goto fail_key;
    }

    /* NOTE(review): the hash context is deliberately not freed on success,
       matching the original code -- presumably PGPPrivateKeySign takes
       ownership of it; confirm against the PGPsdk reference. */
    (void) PGPFreePrivateKeyContext (privkey);
    return PKCS7_OK;

fail_key:
    (void) PGPFreePrivateKeyContext (privkey);
fail_hash:
    (void) PGPFreeHashContext (hash);
    return PKCS7_ERROR_SIGN_CALLBACK;
}

/*
 * pkcs7VerifyCallback - hash `tbs` and verify `signature` over that digest
 * with the public key in pgpData->key.
 *
 *   tbs, tbsLen          signed data.
 *   digestAlg            OID string; only SM_OID_ALG_MD5 and SM_OID_ALG_SHA
 *                        are accepted.
 *   digestEncryptionAlg  unused.
 *   signature            signature bytes to check.
 *   cert                 unused (the key comes from pgpData->key).
 *   data                 X509CMSCallbackData* supplying context and key.
 *   asnmem               unused.
 *
 * Returns PKCS7_OK when the signature verifies, PKCS7_ERROR_HASH_ALG for an
 * unsupported digest, and PKCS7_ERROR_HASH_CALLBACK for any other failure
 * (including a signature that does not verify).
 */
int
pkcs7VerifyCallback (
	const unsigned char     *tbs,                   /* signed data */
	size_t                  tbsLen,                 /* signed data len */
	const char              *digestAlg,             /* hash alg */
	const char              *digestEncryptionAlg,   /* alg to decrypt sig */
	PKIEncryptedDigest      *signature,             /* encrypted sig */
	PKICertificate          *cert,                  /* signer cert */
	void                    *data,                  /* [IN] callback data
							   (optional) */
	PKICONTEXT		*asnmem
)
{
    X509CMSCallbackData		*pgpData = (X509CMSCallbackData *) data;
    PGPHashAlgorithm		algid;
    PGPPublicKeyContextRef	pubkey;
    PGPError			err;
    PGPHashContextRef		hash;

	(void) digestEncryptionAlg;
	(void) cert;
	(void) asnmem;

    if (!strcmp (SM_OID_ALG_MD5, digestAlg))
	algid = kPGPHashAlgorithm_MD5;
    else if (!strcmp (SM_OID_ALG_SHA, digestAlg))
	algid = kPGPHashAlgorithm_SHA;
    else
	return PKCS7_ERROR_HASH_ALG;

    err = PGPNewHashContext (pgpData->context, algid, &hash);
    if (IsPGPError (err))
	return PKCS7_ERROR_HASH_CALLBACK;

    err = PGPContinueHash (hash, tbs, tbsLen);
    if (IsPGPError (err))
	goto fail_hash;

    err = PGPNewPublicKeyContext (
	    pgpData->key,
	    kPGPPublicKeyMessageFormat_X509,
	    &pubkey);
    if (IsPGPError (err))
	goto fail_hash;

    err = PGPPublicKeyVerifySignature (pubkey,
	    hash,
	    signature->val,
	    signature->len);
    if (IsPGPError (err))
    {
	(void) PGPFreePublicKeyContext (pubkey);
	goto fail_hash;
    }

    /* NOTE(review): as in the original, the hash context is not freed on
       success -- presumably PGPPublicKeyVerifySignature takes ownership of
       it; confirm against the PGPsdk reference. */
    (void) PGPFreePublicKeyContext (pubkey);
    return PKCS7_OK;

fail_hash:
    (void) PGPFreeHashContext (hash);
    return PKCS7_ERROR_HASH_CALLBACK;
}

/*
 * pkcs7EncryptCallback - encrypt `tbe` for every recipient in `recips`.
 *
 * A random session key and IV are generated.  The data is encrypted with
 * 3DES-CBC, or with DES-CBC emulated as 3DES using K1 = K2 = K3.  The last
 * block gets PKCS #7 padding.  The IV is packed as an OCTET STRING into
 * encryptParam.  The session key is encrypted with each recipient's public
 * key in PKCS #1 message format (RSA only).
 *
 * Returns PKCS7_OK on success and PKCS7_ERROR_CALLBACK on an unsupported
 * algorithm or any SDK / allocation / encoding failure.
 *
 * NOTE(review): on failure, output buffers already allocated from
 * asnmem->memMgr (encryptedData->val, encryptParam->val, per-recipient
 * encryptedKey) are not released here, as in the original; presumably they
 * are reclaimed together with the ASN.1 context -- confirm.
 */
int
pkcs7EncryptCallback (
        PKIOCTET_STRING        *encryptedData, /* [OUT] encrypted data */
        PKIANY                 *encryptParam,  /* [OUT] data encryption
						   parameters (e.g.,
						   initialization vector) */
        const char              *dataEncAlg,    /* [IN] data encryption alg */
        const unsigned char     *tbe,           /* [IN] data to encrypt */
        size_t                  tbelen,         /* [IN] size of data */
        List                    *recips,        /* [IN/OUT] who to encrypt to.
                                                   The callback also returns
                                                   the encrypted session key
                                                   for each recipient in this
						   variable. */
        void                    *data,          /* [IN] user supplied data */
	PKICONTEXT		*asnmem
)
{
    X509CMSCallbackData		*pgpData = (X509CMSCallbackData *) data;
    PGPMemoryMgrRef		mem = PGPPeekContextMemoryMgr (pgpData->context);
    PGPCipherAlgorithm		cipherAlg;
    PGPError			err;
    PGPSize			blockSize, plainTextLen, lastBlockLen, keySize;
    PGPSize			sessionKeyLen;	/* bytes sent to recipients */
    List			*pRecip;
    int				e = 0, singleDES = 0, ret = PKCS7_ERROR_CALLBACK;

    /* Resources released at "cleanup".  Every one starts out NULL so the
       cleanup code never inspects an uninitialized handle. */
    PGPSymmetricCipherContextRef	cipherRef = NULL;
    PGPCBCContextRef		cbcRef = NULL;
    PGPPublicKeyContextRef	pubKey = NULL;
    PKIOCTET_STRING		*iv = NULL;
    PGPByte			*randomBuf = NULL;
    PGPByte			*lastBlock = NULL;

    if (!strcmp (dataEncAlg, SM_OID_ALG_3DES))
	cipherAlg = kPGPCipherAlgorithm_3DES;
    else if (!strcmp (dataEncAlg, SM_OID_ALG_DES)) /* DES-CBC */
    {
	/* PGPsdk does not support 1DES, but we can fake it using 3DES with
		the same key repeated */
	cipherAlg = kPGPCipherAlgorithm_3DES;
	singleDES = 1;
    }
    else
	return PKCS7_ERROR_CALLBACK; /* unsupported algorithm */

    err = PGPNewSymmetricCipherContext (pgpData->context, cipherAlg, &cipherRef);
    if (IsPGPError (err))
	return PKCS7_ERROR_CALLBACK;

    err = PGPGetSymmetricCipherSizes (cipherRef, &keySize, &blockSize);
    if (IsPGPError (err))
	goto cleanup;

    err = PGPNewCBCContext (cipherRef, &cbcRef);
    if (IsPGPError (err))
    {
	cbcRef = NULL;
	goto cleanup;
    }
    cipherRef = NULL; /* cbcRef now owns (and will destroy) the cipher */

    /* Generate a random session key */
    randomBuf = PGPNewSecureData (mem, keySize, 0);
    if (randomBuf == NULL)
	goto cleanup;
    err = PGPContextGetRandomBytes (pgpData->context, randomBuf, keySize);
    if (IsPGPError (err))
	goto cleanup;

    if (singleDES)
    {
	/* replicate the first 8-byte key so that 3DES effectively does 1DES;
	   only that first key is sent to the recipients */
	memcpy (randomBuf + 8, randomBuf, 8);
	memcpy (randomBuf + 16, randomBuf, 8);
	sessionKeyLen = 8;
    }
    else
	sessionKeyLen = keySize;

    /* Generate a random IV of one cipher block */
    iv = PKINewOCTET_STRING (asnmem);
    if (iv == NULL)
	goto cleanup;
    iv->len = blockSize;
    iv->val = PKIAlloc (asnmem->memMgr, iv->len);
    if (iv->val == NULL)
    {
	iv->len = 0;
	goto cleanup;
    }
    err = PGPContextGetRandomBytes (pgpData->context, iv->val, iv->len);
    if (IsPGPError (err))
	goto cleanup;

    err = PGPInitCBC (cbcRef, randomBuf, iv->val);
    if (IsPGPError (err))
	goto cleanup;

    /* output = all full blocks + one final padded block */
    plainTextLen = tbelen - (tbelen % blockSize);
    encryptedData->len = plainTextLen + blockSize;
    encryptedData->val = PKIAlloc (asnmem->memMgr, encryptedData->len);
    if (encryptedData->val == NULL)
    {
	encryptedData->len = 0;
	goto cleanup;
    }

    /* encrypt all full blocks in one pass; skipped when there are none */
    if (plainTextLen)
    {
	err = PGPCBCEncrypt (cbcRef, tbe, plainTextLen, encryptedData->val);
	if (IsPGPError (err))
	    goto cleanup;
    }

    /* pad the last block according to PKCS #7: k bytes of value k, where a
       full block of padding is added when tbelen is block-aligned */
    lastBlockLen = tbelen - plainTextLen;
    lastBlock = PGPNewSecureData (mem, blockSize, 0);
    if (lastBlock == NULL)
	goto cleanup;
    if (lastBlockLen)
	memcpy (lastBlock, tbe + plainTextLen, lastBlockLen);
    memset (lastBlock + lastBlockLen, (int) (blockSize - lastBlockLen),
	    blockSize - lastBlockLen);

    err = PGPCBCEncrypt (cbcRef, lastBlock, blockSize,
	    encryptedData->val + plainTextLen);
    if (IsPGPError (err))
	goto cleanup;

    /* encode the IV as the content-encryption parameters */
    encryptParam->len = PKISizeofOCTET_STRING (asnmem, iv, TRUE);
    encryptParam->val = PKIAlloc (asnmem->memMgr, encryptParam->len);
    if (encryptParam->val == NULL)
    {
	encryptParam->len = 0;
	goto cleanup;
    }
    PKIPackOCTET_STRING (asnmem, encryptParam->val, encryptParam->len, iv, &e);
    if (e)
	goto cleanup;

    /* encrypt session key for each recipient */
    for (pRecip = recips; pRecip != NULL; pRecip = pRecip->next)
    {
	PGPSize decSize;
	PGPSize sigSize;
	EncryptRecipient *info = (EncryptRecipient *) pRecip->data;

	err = PGPNewPublicKeyContext ((PGPKeyDBObjRef) info->data,
		kPGPPublicKeyMessageFormat_PKCS1, /* HACK - only for RSA */
		&pubKey);
	if (IsPGPError (err))
	{
	    pubKey = NULL;
	    goto cleanup;
	}

	err = PGPGetPublicKeyOperationSizes (pubKey,
		&decSize,
		&info->encryptedKeyLen,
		&sigSize);
	if (IsPGPError (err))
	    goto cleanup;

	info->encryptedKey = PKIAlloc (asnmem->memMgr, info->encryptedKeyLen);
	if (info->encryptedKey == NULL)
	    goto cleanup;

	err = PGPPublicKeyEncrypt (pubKey,
		randomBuf,
		sessionKeyLen,
		info->encryptedKey,
		&info->encryptedKeyLen);
	if (IsPGPError (err))
	    goto cleanup;

	/* clear the handle before checking so cleanup never frees it twice */
	err = PGPFreePublicKeyContext (pubKey);
	pubKey = NULL;
	if (IsPGPError (err))
	    goto cleanup;

	/* ??? should this be part of the PKCS7 library rather than the
	   responsibility of the callback? */
	info->algorithm = sm_OIDToString (&info->certificate->tbsCertificate.subjectPublicKeyInfo.algorithm.algorithm, asnmem);
    }

    ret = PKCS7_OK;

cleanup:
    if (pubKey)
	(void) PGPFreePublicKeyContext (pubKey);
    if (cipherRef)
	(void) PGPFreeSymmetricCipherContext (cipherRef);
    if (cbcRef)
	(void) PGPFreeCBCContext (cbcRef);
    if (iv)
	PKIFreeOCTET_STRING (asnmem, iv);
    if (lastBlock)
	PGPFreeData (lastBlock);
    if (randomBuf)
	PGPFreeData (randomBuf);

    return ret;
}

/*
 * pkcs7DecryptCallback - recover the content-encryption key with our
 * private key, CBC-decrypt `content`, and strip its PKCS #7 padding.
 *
 * Only 3DES-CBC and DES-CBC are supported.  DES is emulated with 3DES using
 * K1 = K2 = K3.  The session-key length, IV length, ciphertext length and
 * padding all come from the encoded message, so each is validated before it
 * is used to index or size anything.
 *
 *   msg, msgLen    [OUT] on success, plaintext allocated from asnmem->memMgr
 *                        and its length.  On any failure *msg is NULL and
 *                        *msgLen is 0.
 *   contentEncAlg  OID of the content-encryption algorithm.
 *   param          packed OCTET STRING holding the IV.
 *   content        ciphertext.
 *   keyEncAlg      unused.
 *   enckey         session key encrypted to our public key (PKCS #1 format).
 *   cert           unused (the key comes from pgpData->key).
 *   data           X509CMSCallbackData* supplying context, key, passphrase.
 *
 * Returns PKCS7_OK or PKCS7_ERROR_CALLBACK.
 */
int
pkcs7DecryptCallback (
        unsigned char           **msg,          /* [OUT] decrypted data */
        size_t                  *msgLen,        /* [OUT] decrypted data len */
        const char              *contentEncAlg, /* [IN] data encrypted alg */
        PKIANY			*param,     	/* [IN] data encryption
						   parameter (e.g.,
						   initialization vector) */
        PKIEncryptedContent     *content,       /* [IN] encrypted data */
        const char              *keyEncAlg,     /* [IN] key encryption alg */
        PKIEncryptedKey         *enckey,        /* [IN] encrypted key */
        PKICertificate          *cert,          /* [IN] key to decrypt with */
        void                    *data,          /* [IN] callback data
						   (optional) */
	PKICONTEXT		*asnmem
)
{
    X509CMSCallbackData			*pgpData = (X509CMSCallbackData *) data;
    PGPError				err;
    PGPSize				decmax;
    PGPSize				neededKeyLen;
    PGPMemoryMgrRef			mem;
    PGPCipherAlgorithm			symKeyAlg;
    PGPSize				keySize;
    PGPSize				blockSize;
    PGPOptionListRef			pass;
    int					e = 0, singleDES = 0;
    int					ret = PKCS7_ERROR_CALLBACK;
    size_t				pad, i;

    /* Resources released at "cleanup"; NULL means "not currently held". */
    PGPPrivateKeyContextRef		privKey = NULL;
    PGPSymmetricCipherContextRef	cipherRef = NULL;
    PGPCBCContextRef			cbcRef = NULL;
    PGPByte				*encKeyData = NULL;
    PKIOCTET_STRING			*iv = NULL;

	(void) keyEncAlg;
	(void) cert;

    *msg = NULL;
    *msgLen = 0;

    if (!strcmp (contentEncAlg, SM_OID_ALG_3DES))
	symKeyAlg = kPGPCipherAlgorithm_3DES;
    else if (!strcmp (contentEncAlg, SM_OID_ALG_DES))
    {
	symKeyAlg = kPGPCipherAlgorithm_3DES;
	singleDES = 1;
    }
    else
    {
	/* unsupported algorithm */
	return PKCS7_ERROR_CALLBACK;
    }

    /* decrypt the encrypted session key with our private key */
    if (pgpData->passphrase == NULL)
	pass = PGPONullOption (pgpData->context);
    else
    {
	err = PGPCopyOptionList (pgpData->passphrase, &pass);
	if (IsPGPError (err))
	    return PKCS7_ERROR_CALLBACK;
    }
    err = PGPNewPrivateKeyContext (pgpData->key,
	    kPGPPublicKeyMessageFormat_PKCS1,
	    &privKey,
	    pass,
	    PGPOLastOption (pgpData->context));
    if (IsPGPError (err))
	return PKCS7_ERROR_CALLBACK;

    err = PGPGetPrivateKeyOperationSizes (privKey, &decmax, NULL, NULL);
    if (IsPGPError (err))
	goto cleanup;

    mem = PGPPeekContextMemoryMgr (pgpData->context);

    encKeyData = PGPNewSecureData (mem, decmax, 0);
    if (encKeyData == NULL)
	goto cleanup;

    /* on success decmax is updated to the decrypted session-key length */
    err = PGPPrivateKeyDecrypt (privKey,
	    enckey->val,
	    enckey->len,
	    encKeyData,
	    &decmax);
    if (IsPGPError (err))
	goto cleanup;

    (void) PGPFreePrivateKeyContext (privKey);
    privKey = NULL;

    /* unpack the parameters for the decryption (IV) */
    PKIUnpackOCTET_STRING (asnmem, &iv, param->val, param->len, &e);
    if (e)
    {
	/* NOTE(review): like the original, do not free a possibly partial iv
	   after a decode failure -- confirm PKIUnpackOCTET_STRING's cleanup
	   contract before changing this. */
	iv = NULL;
	goto cleanup;
    }

    /* now decrypt the message data with the extracted key */
    err = PGPNewSymmetricCipherContext (pgpData->context, symKeyAlg, &cipherRef);
    if (IsPGPError (err))
    {
	cipherRef = NULL;
	goto cleanup;
    }

    err = PGPGetSymmetricCipherSizes (cipherRef, &keySize, &blockSize);
    if (IsPGPError (err))
	goto cleanup;

    /* Validate message-supplied lengths before they are used:
       - the decrypted key must supply every byte the cipher key needs,
         otherwise uninitialized buffer bytes would become key material;
       - PGPInitCBC reads blockSize bytes of IV;
       - the ciphertext must be a non-empty whole number of blocks, or the
         padding byte lookup below would read outside the buffer. */
    neededKeyLen = singleDES ? 8 : keySize;
    if (decmax < neededKeyLen)
	goto cleanup;
    if (iv->len < blockSize)
	goto cleanup;
    if (content->len == 0 || content->len % blockSize != 0)
	goto cleanup;

    err = PGPNewCBCContext (cipherRef, &cbcRef);
    if (IsPGPError (err))
    {
	cbcRef = NULL;
	goto cleanup;
    }
    cipherRef = NULL; /* cbcRef now owns (and will destroy) the cipher */

    if (singleDES)
    {
	/* convert 1DES key into 3DES key */
	PGPByte *key3;

	pgpAssert (decmax == 8);
	pgpAssert (keySize == 3 * decmax);
	key3 = PGPNewSecureData (mem, keySize, 0);
	if (key3 == NULL)
	    goto cleanup;
	memcpy (key3, encKeyData, 8);
	memcpy (key3 + 8, encKeyData, 8);
	memcpy (key3 + 16, encKeyData, 8);
	PGPFreeData (encKeyData);
	encKeyData = key3;
    }

    err = PGPInitCBC (cbcRef, encKeyData, iv->val);
    if (IsPGPError (err))
	goto cleanup;

    /* key and IV are no longer needed once the CBC context is initialized;
       release the key copy as early as possible */
    PKIFreeOCTET_STRING (asnmem, iv);
    iv = NULL;
    PGPFreeData (encKeyData);
    encKeyData = NULL;

    *msg = PKIAlloc (asnmem->memMgr, content->len);
    if (*msg == NULL)
	goto cleanup;

    err = PGPCBCDecrypt (cbcRef, content->val, content->len, *msg);
    if (IsPGPError (err))
	goto fail_msg;

    /* Remove PKCS #7 padding: the last byte k must be in 1..blockSize and
       the final k bytes must all equal k.  k <= blockSize <= content->len,
       so the scan below stays inside the buffer. */
    pad = (*msg)[content->len - 1];
    if (pad < 1 || pad > blockSize)
	goto fail_msg; /* invalid pad, decryption failed */
    for (i = content->len - pad; i < content->len - 1; i++)
    {
	if ((*msg)[i] != pad)
	    goto fail_msg;
    }

    *msgLen = content->len - pad;
    ret = PKCS7_OK;
    goto cleanup;

fail_msg:
    /* never return an error together with an allocated output */
    PKIFree (asnmem->memMgr, *msg);
    *msg = NULL;
    *msgLen = 0;

cleanup:
    if (cbcRef)
	(void) PGPFreeCBCContext (cbcRef);
    if (cipherRef)
	(void) PGPFreeSymmetricCipherContext (cipherRef);
    if (privKey)
	(void) PGPFreePrivateKeyContext (privKey);
    if (iv)
	PKIFreeOCTET_STRING (asnmem, iv);
    if (encKeyData)
	PGPFreeData (encKeyData);

    return ret;
}
